#include "TrainHist.h"

using namespace std;
using namespace cv;
using namespace boost::filesystem;
typedef vector<path> vec_path;

////////////////////////////////////////////////////////////////////////////////
//	PUBLIC METHODS
////////////////////////////////////////////////////////////////////////////////
// Default constructor: fixes the file name under which training results are
// cached. The load/save helpers prepend the relevant directory to this name.
TrainHist::TrainHist () {
	this->_training_results = "hist_training_results.xml";
}

// Public training entry point.
// Forwards all arguments unchanged to train_directory() and returns its
// result.
bool TrainHist::train (std::string dir,
						std::vector<std::string> separators,
						std::vector<std::string> names,
						float* interval) {
	return train_directory (dir, separators, names, interval);
}

// Writes the accumulated descriptor to the results file next to `path`.
//   path     : file or directory whose name becomes the entry key.
//   override : true rewrites the results file, false appends to it.
// Returns the status reported by save_descriptor().
bool TrainHist::save_training (string path, bool override) {
	const string entry = path2name(path);
	cout << "Saving " << entry << endl;
	const bool saved = save_descriptor (path, _descriptor, override);
	return saved;
}

// Loads the descriptor stored for `path` and merges it into the accumulated
// descriptor.
//   path : file or directory whose name is the entry key; the results file is
//          looked up in the parent location of `path`.
// Returns false (leaving the accumulated descriptor untouched) when the
// results file cannot be opened or holds no descriptor for this entry.
bool TrainHist::load_training (string path) {
	string name = path2name(path);
	string result_path = path2rest(path) + _training_results;
	FileStorage fs;
	if (!fs.open(result_path, FileStorage::READ)){
		// Report the file that actually failed to open, not the entry path
		cout << "load_training : couldn't open " << result_path << endl;
		return false;
	}
	Mat desc;
	fs[name] >> desc;
	fs.release();
	// A missing entry reads back as an empty matrix; merging it would add
	// matrices of different sizes.
	if (desc.empty()) {
		cout << "load_training : no entry " << name << " in " << result_path << endl;
		return false;
	}
	_descriptor = merge (_descriptor, desc);
	cout << "Loaded training results from " << result_path << endl;
	return true;
}

void TrainHist::clear_training () {
	_descriptor = Mat();
}
////////////////////////////////////////////////////////////////////////////////
//	PRIVATE METHODS
////////////////////////////////////////////////////////////////////////////////
bool TrainHist::train_directory (std::string dir,
										 							std::vector<std::string> separators,
 																	std::vector<std::string> names,
										 							float* interval) {
	// Check if dir is a directory
	if (!is_directory (dir)) {
		cout << dir << " not a directory." << endl;
		return false;
	}
	// if this directory has already been trained, load the results
	if (already_trained(dir)) {
		cout << dir << " already trained." << endl;
		load_training (dir);
	} else {
    vec_path v; // so we can sort them later
    copy(directory_iterator(dir), directory_iterator(), back_inserter(v));
    sort(v.begin(), v.end());	// sort, directory iteration is not ordered
		// compute number of files
    int n = v.size ();
    int start = n*interval[0];
    int end = n*interval[1];
    for (int i = start; i < end; ++i) {
			if (is_directory(v[i].string()))
				train_directory (v[i].string(), separators, names, interval);
			// must match main/first file format
			else if (is_right (v[i].string(), separators[0], names[0]))
				train_file (v[i].string(), separators, names);
		}
		save_training (dir, false);
	}
	return true;
}

// Trains on a single file.
// If a descriptor for this file is already stored it is loaded; otherwise it
// is extracted, merged into the accumulated descriptor, and appended to the
// results file under the file's name. Always returns true.
bool TrainHist::train_file (string path, vector<string> separators,
							vector<string> names) {
	if (!already_trained(path)) {
		// No cached result for this file: compute, accumulate, then store it
		Mat fresh = extract (path, separators, names);
		_descriptor = merge (_descriptor, fresh);
		save_descriptor (path, fresh, false);
		return true;
	}
	// Cached result available: reuse it
	cout << path << " already trained." << endl;
	load_training (path);
	return true;
}

// Stores `descriptor` in the results file located next to `path`, keyed by
// the name derived from `path`.
//   override : true opens the file in WRITE mode, false in APPEND mode.
// Returns false when the results file cannot be opened.
bool TrainHist::save_descriptor (string path, Mat descriptor, bool override) {
	const string key = path2name(path);
	const string storage_file = path2rest(path) + _training_results;
	const int open_mode = override ? FileStorage::WRITE : FileStorage::APPEND;

	FileStorage out;
	if (out.open(storage_file, open_mode)) {
		out << key << descriptor;
		out.release();
		return true;
	}
	cout << "save_training : couldn't open " << path << endl;
	return false;
}

// Tells whether a result for `path` is already stored.
// Returns true only when the results file next to `path` can be opened and
// contains a node named after `path`.
bool TrainHist::already_trained (string path) {
	const string key = path2name(path);
	const string storage_file = path2rest(path) + _training_results;
	FileStorage reader;
	// No results file at all means nothing has been trained here
	if (!reader.open(storage_file, FileStorage::READ))
		return false;
	// Trained only if the file holds a node with this name
	const bool present = reader[key].type() != FileNode::NONE;
	reader.release();
	return present;
}

// Extracts the entry name from a '/'-separated path: the last path component
// (ignoring one trailing '/'), cut at its first '.'.
//   "a/b/file.tar.gz" -> "file",  "a/b/" -> "b",  "" -> ""
// An empty path yields an empty name instead of indexing out of bounds.
string TrainHist::path2name (string path) {
	// Ignore a single trailing '/' so "a/b/" and "a/b" name the same component
	if (!path.empty() && path[path.size()-1] == '/')
		path.erase (path.size()-1);
	// Keep what follows the last remaining '/' (the whole string if none)
	const string::size_type slash = path.rfind ('/');
	const string filename = (slash == string::npos) ? path
	                                                : path.substr (slash+1);
	// The name ends at the first '.'; without one, find() returns npos and
	// substr() keeps the whole component.
	return filename.substr (0, filename.find ('.'));
}

// Convenience overload: returns everything in `path` before its last
// component, using "/" as the separator.
string TrainHist::path2rest (string path) {
	const string slash ("/");
	return path2rest (path, slash);
}

// Returns `path` without its last component.
// The path is split on any character contained in `separator`; every kept
// token is re-emitted followed by the full `separator` string, so the result
// ends with the separator. A trailing separator in `path` is ignored when
// picking the component to drop.
string TrainHist::path2rest (string path, string separator) {
	vector<string> parts;
	boost::split (parts, path, boost::algorithm::is_any_of(separator));
	// An empty final token means the path ended with a separator: drop that
	// token as well as the real last component.
	const int dropped = parts.back().empty() ? 2 : 1;
	const int kept = static_cast<int>(parts.size()) - dropped;
	string prefix;
	for (int idx = 0; idx < kept; ++idx) {
		prefix += parts[idx];
		prefix += separator;
	}
	return prefix;
}

// Checks whether `file` has the expected ending: the text after the last
// occurrence of any character of `separator` must equal `name`.
bool TrainHist::is_right (string file, string separator, string name) {
	vector<string> tokens;
	boost::split (tokens, file, boost::algorithm::is_any_of(separator));
	return tokens.back() == name;
}
////////////////////////////////////////////////////////////////////////////////
//  Modify these functions to suit the type of descriptor. The rest should be
//	generic as long as the descriptor can be stored in a cv::Mat.
////////////////////////////////////////////////////////////////////////////////
// Computes the descriptor of one training sample.
// The image path is built from `path` cut at separators[0] plus names[0]; the
// mask path from `path` cut at separators[1] plus names[1]. The image is read
// in colour, the mask in grayscale, and both are passed to
// hist.classic_histogram().
// NOTE(review): the meaning of the literal 3 passed to classic_histogram is
// not visible here (presumably a channel or bin count) - confirm.
Mat TrainHist::extract (string path, vector<string> separators,
						vector<string> names) {
	const string imgpath = path2rest(path, separators[0]) + names[0];
	const string maskpath = path2rest(path, separators[1]) + names[1];
	Mat img = imread(imgpath, CV_LOAD_IMAGE_COLOR);
	Mat mask = imread(maskpath, CV_LOAD_IMAGE_GRAYSCALE);
	return hist.classic_histogram (img, mask, 3);
}

// Combines two descriptors into one min-max normalised descriptor in [0, 1].
//   - both present : element-wise sum, then normalised.
//   - one empty    : normalised copy of the other one.
//   - both empty   : empty matrix.
// The inputs are never modified: cv::Mat assignment shares the pixel buffer,
// so the single-operand cases work on a clone before normalising in place.
Mat TrainHist::merge (Mat descriptor1, Mat descriptor2) {
	Mat desc;
	const bool has1 = !descriptor1.empty();
	const bool has2 = !descriptor2.empty();
	if (!has1 && !has2)
		return desc;	// nothing to merge, nothing to normalise
	if (has1 && has2)
		desc = descriptor1 + descriptor2;
	else if (has2)
		desc = descriptor2.clone();	// clone: keep the caller's data intact
	else
		desc = descriptor1.clone();
	normalize (desc, desc, 0, 1, NORM_MINMAX, -1, Mat() );
	return desc;
}

